import os
import glob        # used by the commented-out load_checkpoint logic below
import torch
import traceback   # used by the commented-out fail-safe logic below
# from ltr.admin import loading, multigpu  # needed by the commented-out checkpoint helpers below


class BaseTrainer:
    """Base trainer class. Contains functions for training and saving/loading chackpoints.
    Trainer classes should inherit from this one and overload the train_epoch function."""

    def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
        """
        args:
            actor - The actor for training the network
            loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
                        epoch for each loader.
            optimizer - The optimizer used for training, e.g. Adam
            settings - Training settings
            lr_scheduler - Learning rate scheduler
        """
        self.actor = actor
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loaders = loaders

        # self.update_settings(settings)

        self.epoch = 0
        self.stats = {}

        self.device = getattr(settings, 'device', None)
        if self.device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.actor.to(self.device)

        self.settings = settings

    # def update_settings(self, settings=None):
    #     """Updates the trainer settings. Must be called to update internal settings."""
    #     if settings is not None:
    #         self.settings = settings
    #
    #     if self.settings.env.workspace_dir is not None:
    #         self.settings.env.workspace_dir = os.path.expanduser(self.settings.env.workspace_dir)
    #         self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir, 'checkpoints')
    #         if not os.path.exists(self._checkpoint_dir):
    #             os.makedirs(self._checkpoint_dir)
    #     else:
    #         self._checkpoint_dir = None


    def train(self, max_epochs, load_latest=False, fail_safe=True):
        """Do training for the given number of epochs.
        args:
            max_epochs - Max number of training epochs.
            load_latest - Bool indicating whether to resume from the latest saved epoch.
            fail_safe - Bool indicating whether to automatically restart training in case of a crash.
                        Note: both flags are currently unused, since the fail-safe retry loop below is commented out.
        """

        # Fail-safe retry logic (disabled): the variables and commented-out
        # try/except below would restart training from the latest checkpoint
        # after a crash.
        epoch = -1
        num_tries = 10
        # for i in range(num_tries):
            # try:
                # if load_latest:
                #     self.load_checkpoint()

        for epoch in range(self.epoch+1, max_epochs+1):
            self.epoch = epoch

            self.train_epoch()

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            print('saving epoch {:d}'.format(epoch))
            self.save_model(self.actor.net, epoch, self.optimizer, self.settings.ONLINE.TRAIN.MODEL, self.settings, isbest=False)


            # if self._checkpoint_dir:
            #     self.save_checkpoint()
            # except:
            #     print('Training crashed at epoch {}'.format(epoch))
            #     if fail_safe:
            #         self.epoch -= 1
            #         load_latest = True
            #         print('Traceback for the error!')
            #         print(traceback.format_exc())
            #         print('Restarting training from last epoch ...')
            #     else:
            #         raise

        print('Finished training!')


    def train_epoch(self):
        """Runs one epoch of training. Must be overloaded by the subclass."""
        raise NotImplementedError

    def save_model(self, model, epoch, optimizer, model_name, cfg, isbest=False):
        """Saves the model and optimizer state for the given epoch.
        args:
            model - The network to save.
            epoch - Current epoch number.
            optimizer - The optimizer whose state is saved alongside the model.
            model_name - Architecture name stored in the checkpoint.
            cfg - Config object providing CHECKPOINT_DIR.
            isbest - If True, the state dict is additionally saved as model_best.pth.
        """
        if not os.path.exists(cfg.CHECKPOINT_DIR):
            os.makedirs(cfg.CHECKPOINT_DIR)

        if epoch > 0:
            self.save_checkpoint({
                'epoch': epoch + 1,
                'arch': model_name,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, isbest, cfg.CHECKPOINT_DIR, 'checkpoint_e%d.pth' % (epoch + 1))
        else:
            print('epoch {:d} not saved (epoch <= 0)'.format(epoch))

    def save_checkpoint(self, states, is_best, output_dir, filename='checkpoint.pth.tar'):
        """Writes the given state dict to output_dir/filename. If is_best is set,
        the network weights are additionally saved as model_best.pth."""
        torch.save(states, os.path.join(output_dir, filename))
        if is_best and 'state_dict' in states:
            torch.save(states['state_dict'],
                       os.path.join(output_dir, 'model_best.pth'))
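
    # A minimal sketch (not part of the original code) of how a checkpoint written
    # by save_checkpoint above could be restored. The key names ('epoch',
    # 'state_dict', 'optimizer', 'arch') match what save_model stores; the method
    # name load_saved_checkpoint is hypothetical.
    def load_saved_checkpoint(self, checkpoint_path):
        """Restores network and optimizer state from a checkpoint_e*.pth file."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.actor.net.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # save_model stores epoch + 1, so subtract 1 to make train() resume
        # from the epoch after the saved one.
        self.epoch = checkpoint['epoch'] - 1
        return checkpoint.get('arch', None)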


    # def save_checkpoint(self):
    #     """Saves a checkpoint of the network and other variables."""
    #
    #     net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
    #
    #     actor_type = type(self.actor).__name__
    #     net_type = type(net).__name__
    #     state = {
    #         'epoch': self.epoch,
    #         'actor_type': actor_type,
    #         'net_type': net_type,
    #         'net': net.state_dict(),
    #         'net_info': getattr(net, 'info', None),
    #         'constructor': getattr(net, 'constructor', None),
    #         'optimizer': self.optimizer.state_dict(),
    #         'stats': self.stats,
    #         'settings': self.settings
    #     }
    #
    #
    #     directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
    #     if not os.path.exists(directory):
    #         os.makedirs(directory)
    #
    #     file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)
    #     torch.save(state, file_path)


    # def load_checkpoint(self, checkpoint = None, fields = None, ignore_fields = None, load_constructor = False):
    #     """Loads a network checkpoint file.
    #
    #     Can be called in three different ways:
    #         load_checkpoint():
    #             Loads the latest epoch from the workspace. Use this to continue training.
    #         load_checkpoint(epoch_num):
    #             Loads the network at the given epoch number (int).
    #         load_checkpoint(path_to_checkpoint):
    #             Loads the file from the given absolute path (str).
    #     """
    #
    #     net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
    #
    #     actor_type = type(self.actor).__name__
    #     net_type = type(net).__name__
    #
    #     if checkpoint is None:
    #         # Load most recent checkpoint
    #         checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
    #                                                                          self.settings.project_path, net_type)))
    #         if checkpoint_list:
    #             checkpoint_path = checkpoint_list[-1]
    #         else:
    #             print('No matching checkpoint file found')
    #             return
    #     elif isinstance(checkpoint, int):
    #         # Checkpoint is the epoch number
    #         checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
    #                                                              net_type, checkpoint)
    #     elif isinstance(checkpoint, str):
    #         # checkpoint is the path
    #         if os.path.isdir(checkpoint):
    #             checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
    #             if checkpoint_list:
    #                 checkpoint_path = checkpoint_list[-1]
    #             else:
    #                 raise Exception('No checkpoint found')
    #         else:
    #             checkpoint_path = os.path.expanduser(checkpoint)
    #     else:
    #         raise TypeError
    #
    #     # Load network
    #     checkpoint_dict = loading.torch_load_legacy(checkpoint_path)
    #
    #     assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
    #
    #     if fields is None:
    #         fields = checkpoint_dict.keys()
    #     if ignore_fields is None:
    #         ignore_fields = ['settings']
    #
    #         # Never load the scheduler. It exists in older checkpoints.
    #     ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])
    #
    #     # Load all fields
    #     for key in fields:
    #         if key in ignore_fields:
    #             continue
    #         if key == 'net':
    #             net.load_state_dict(checkpoint_dict[key])
    #         elif key == 'optimizer':
    #             self.optimizer.load_state_dict(checkpoint_dict[key])
    #         else:
    #             setattr(self, key, checkpoint_dict[key])
    #
    #     # Set the net info
    #     if load_constructor and 'constructor' in checkpoint_dict and checkpoint_dict['constructor'] is not None:
    #         net.constructor = checkpoint_dict['constructor']
    #     if 'net_info' in checkpoint_dict and checkpoint_dict['net_info'] is not None:
    #         net.info = checkpoint_dict['net_info']
    #
    #     # Update the epoch in lr scheduler
    #     if 'epoch' in fields:
    #         self.lr_scheduler.last_epoch = self.epoch
    #
    #     return True
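

# --- Usage sketch (hypothetical, for illustration only) ---
# BaseTrainer is abstract: a subclass must implement train_epoch. The dummy actor
# and settings objects below are assumptions made so the example is self-contained;
# the real actor/settings classes live elsewhere in the codebase.
if __name__ == '__main__':
    import torch.nn as nn
    from types import SimpleNamespace

    class DummyActor(nn.Module):
        """Stand-in actor: wraps a tiny network and returns a scalar loss."""
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 1)

        def forward(self, x):
            return self.net(x).mean()

    class ToyTrainer(BaseTrainer):
        """Minimal subclass: one optimization step per epoch on random data."""
        def train_epoch(self):
            x = torch.randn(8, 4, device=self.device)
            loss = self.actor(x)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            print('epoch {:d}, loss {:.4f}'.format(self.epoch, loss.item()))

    # Hypothetical settings object exposing the attributes BaseTrainer reads:
    # device, CHECKPOINT_DIR, and ONLINE.TRAIN.MODEL (used by save_model in train()).
    settings = SimpleNamespace(
        device=torch.device('cpu'),
        CHECKPOINT_DIR='./checkpoints',
        ONLINE=SimpleNamespace(TRAIN=SimpleNamespace(MODEL='toy_model')),
    )

    actor = DummyActor()
    optimizer = torch.optim.SGD(actor.parameters(), lr=0.1)
    trainer = ToyTrainer(actor, loaders=[], optimizer=optimizer, settings=settings)
    trainer.train(max_epochs=3)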
